{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Distributed training using torch.distributed.launch module on Azure Machine Learning\n",
    "\n",
    "\n",
    "This example show how to train language model using the huggingface library  distributed on Azure Machine Learning using pytorch estimator."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "\n",
    "import wget\n",
    "import os\n",
    "\n",
    "from azureml.core import (Workspace, Experiment, \n",
    "                          VERSION, Datastore)\n",
    "from azureml.core.compute import ComputeTarget, AmlCompute\n",
    "from azureml.core.environment import Environment,CondaDependencies\n",
    "from azureml.train.dnn import PyTorch,Nccl\n",
    "from azureml.data.data_reference import DataReference\n",
    "from azureml.core.compute_target import ComputeTargetException\n",
    "from azureml.widgets import RunDetails\n",
    "\n",
    "SUBSCRIPTION_ID = \"\"\n",
    "RESOURCE_GROUP = \"\"\n",
    "WORKSPACE_NAME = \"\"\n",
    "\n",
    "\n",
    "NUM_NODES = 8\n",
    "NUM_GPU_PER_NODE = 2\n",
    "SKU = 'Standard_NC12'\n",
    "EXP_NAME = 'Azureml-LM_huggingface_example'\n",
    "CLUSTER_NAME = 'two-gpu-cluster'\n",
    "\n",
    "RUN_DIR = os.getcwd()\n",
    "DATA_DIR = 'data'\n",
    "\n",
    "print(\"SDK version:\", VERSION)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ws = Workspace(subscription_id = SUBSCRIPTION_ID, \n",
    "               resource_group =RESOURCE_GROUP , \n",
    "               workspace_name = WORKSPACE_NAME\n",
    "              )\n",
    "\n",
    "\n",
    "exp = Experiment(workspace=ws, name=EXP_NAME)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "os.makedirs(DATA_DIR, exist_ok=True)\n",
    "wget.download(\"https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip\",\n",
    "              out=DATA_DIR\n",
    "             ) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "datastore = ws.get_default_datastore()\n",
    "ds_reference = datastore.upload(src_dir='data',\n",
    "                 target_path='wikitext',\n",
    "                 overwrite=True,\n",
    "                 show_progress=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from azureml.core.compute import AmlCompute\n",
    "from azureml.core.compute import ComputeTarget\n",
    "\n",
    "\n",
    "\n",
    "found = False\n",
    "cts = ws.compute_targets\n",
    "if CLUSTER_NAME in cts and cts[CLUSTER_NAME].type == 'AmlCompute':\n",
    "    found = True\n",
    "    print('Found existing compute target.')\n",
    "    compute_target = cts[CLUSTER_NAME]\n",
    "\n",
    "if not found:\n",
    "    print('Creating a new compute target...')\n",
    "    provisioning_config = AmlCompute.provisioning_configuration(vm_size = SKU,max_nodes = NUM_NODES)\n",
    "\n",
    "    # Create the cluster.\\n\",\n",
    "    compute_target = ComputeTarget.create(ws, CLUSTER_NAME, provisioning_config)\n",
    "\n",
    "print('Checking cluster status...')\n",
    "compute_target.wait_for_completion(show_output = True, min_node_count = None, timeout_in_minutes = 20)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "script_params = {\n",
    "    '--dataset-path':ds_reference.as_mount(),\n",
    "    '--rank':'$AZ_BATCHAI_TASK_INDEX',\n",
    "    '--node_count':NUM_NODES,\n",
    "    '--process_per_node':NUM_GPU_PER_NODE,\n",
    "    '--batch_size':'2'\n",
    "}\n",
    "\n",
    "\n",
    "est = PyTorch(source_directory=RUN_DIR,\n",
    "                pip_packages=['gitpython','scikit-learn','seqeval','tensorboardX',\\\n",
    "                              'tqdm','transformers'],\n",
    "                script_params=script_params,\n",
    "                use_gpu=True,\n",
    "                compute_target=compute_target,\n",
    "                entry_script=os.path.join(RUN_DIR,'train.py'),\n",
    "                framework_version='1.4',\n",
    "                node_count=NUM_NODES,\n",
    "                distributed_training=Nccl()\n",
    "            )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "run = exp.submit(est)\n",
    "RunDetails(run).show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "run"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:amlsamples_env]",
   "language": "python",
   "name": "conda-env-amlsamples_env-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
